{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-h05_PNNhZ-D"
   },
   "source": [
    "# Working with Pytrees\n",
    "\n",
    "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05.1-pytrees.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/05.1-pytrees.ipynb)\n",
    "\n",
    "*Author: Vladimir Mikulik*\n",
    "\n",
    "Often, we want to operate on objects that look like dicts of arrays, or lists of lists of dicts, or other nested structures. In JAX, we refer to these as *pytrees*, but you can sometimes see them called *nests*, or just *trees*.\n",
    "\n",
    "JAX has built-in support for such objects, both in its library functions and through the functions in [`jax.tree_util`](https://jax.readthedocs.io/en/latest/jax.tree_util.html) (with the most common ones also available as `jax.tree_*`). This section explains how to use them, gives some useful snippets, and points out common gotchas."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9UjxVY9ulSCn"
   },
   "source": [
    "## What is a pytree?\n",
    "\n",
    "As defined in the [JAX pytree docs](https://jax.readthedocs.io/en/latest/pytrees.html):\n",
    "\n",
    "> a pytree is a container of leaf elements and/or more pytrees. Containers include lists, tuples, and dicts. A leaf element is anything that’s not a pytree, e.g. an array. In other words, a pytree is just a possibly-nested standard or user-registered Python container. If nested, note that the container types do not need to match. A single “leaf”, i.e. a non-container object, is also considered a pytree.\n",
    "\n",
    "Some example pytrees:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "executionInfo": {
     "elapsed": 11002,
     "status": "ok",
     "timestamp": 1692698031720,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": -60
    },
    "id": "Wh6BApZ9lrR1",
    "outputId": "df1fa4cd-88a6-4d71-a376-b2ddf91568dd"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1, 'a', <object object at 0x7fada8e9e9d0>]   has 3 leaves: [1, 'a', <object object at 0x7fada8e9e9d0>]\n",
      "(1, (2, 3), ())                               has 3 leaves: [1, 2, 3]\n",
      "[1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves: [1, 2, 3, 4, 5]\n",
      "{'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]\n",
      "Array([1, 2, 3], dtype=int32)                 has 1 leaves: [Array([1, 2, 3], dtype=int32)]\n"
     ]
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "example_trees = [\n",
    "    [1, 'a', object()],\n",
    "    (1, (2, 3), ()),\n",
    "    [1, {'k1': 2, 'k2': (3, 4)}, 5],\n",
    "    {'a': 2, 'b': (2, 3)},\n",
    "    jnp.array([1, 2, 3]),\n",
    "]\n",
    "\n",
    "# Let's see how many leaves they have:\n",
    "for pytree in example_trees:\n",
    "  leaves = jax.tree_util.tree_leaves(pytree)\n",
    "  print(f\"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_tWkkGNwW8vf"
   },
   "source": [
    "We've also introduced our first pytree utility, `jax.tree_util.tree_leaves`, which extracts the flattened leaves from a tree."
   ]
  },
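  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of what flattening involves, the cell below uses `jax.tree_util.tree_flatten`, which returns the leaves together with a `treedef` describing the container structure, and `jax.tree_util.tree_unflatten`, which rebuilds a pytree of the same structure from a new list of leaves. The multiply-by-10 transformation is an arbitrary choice for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "# The same nested tree as in one of the examples above.\n",
    "tree = [1, {'k1': 2, 'k2': (3, 4)}, 5]\n",
    "\n",
    "# Flatten: a list of leaves plus a treedef that records the container structure.\n",
    "leaves, treedef = jax.tree_util.tree_flatten(tree)\n",
    "print(leaves)\n",
    "print(treedef)\n",
    "\n",
    "# Unflatten: rebuild a tree with the same structure from a new list of leaves.\n",
    "rebuilt = jax.tree_util.tree_unflatten(treedef, [x * 10 for x in leaves])\n",
    "print(rebuilt)"
   ]
  },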
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RcsmneIGlltm"
   },
   "source": [
    "## Why pytrees?\n",
    "\n",
    "In machine learning, some places where you commonly find pytrees are:\n",
    "* Model parameters\n",
    "* Dataset entries\n",
    "* RL agent observations\n",
    "\n",
    "They also often arise naturally when working in bulk with datasets (e.g., lists of lists of dicts)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sMrSGSIJn9MD"
   },
   "source": [
    "## Common pytree functions\n",
    "Perhaps the most commonly used pytree function is `jax.tree_map`. It works analogously to Python's native `map`, but on entire pytrees, applying the function to every leaf and returning a new pytree with the same structure:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wZRcuQu4n7o5",
    "outputId": "3528bc9f-54ed-49c8-b79a-1cbea176c0f3"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[2, 4, 6], [2, 4], [2, 4, 6, 8]]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "list_of_lists = [\n",
    "    [1, 2, 3],\n",
    "    [1, 2],\n",
    "    [1, 2, 3, 4]\n",
    "]\n",
    "\n",
    "jax.tree_map(lambda x: x*2, list_of_lists)"
   ]
  },
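  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The output keeps the input's structure even when containers are mixed and nested. The short sketch below uses a hypothetical `nested` tree and `jax.tree_util.tree_structure`, which returns a `PyTreeDef` describing the containers. Two structurally identical trees compare equal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "# A hypothetical nested tree mixing dicts, lists and tuples.\n",
    "nested = {'a': [1, 2], 'b': (3, {'c': 4})}\n",
    "\n",
    "# tree_map applies the function to every leaf and rebuilds the same containers.\n",
    "doubled = jax.tree_util.tree_map(lambda x: x * 2, nested)\n",
    "print(doubled)\n",
    "\n",
    "# The container structure (the treedef) is unchanged.\n",
    "print(jax.tree_util.tree_structure(doubled) == jax.tree_util.tree_structure(nested))"
   ]
  },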
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xu8X3fk4orC9"
   },
   "source": [
    "`jax.tree_map` also works with multiple arguments:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "KVpB4r1OkeUK",
    "outputId": "33f88a7e-aac7-48cd-d207-2c531cd37733"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[2, 4, 6], [2, 4], [2, 4, 6, 8]]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "another_list_of_lists = list_of_lists\n",
    "jax.tree_map(lambda x, y: x+y, list_of_lists, another_list_of_lists)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dkRKy3LvowAb"
   },
   "source": [
    "When using multiple arguments with `jax.tree_map`, the structure of the inputs must exactly match. That is, lists must have the same number of elements, dicts must have the same keys, etc."
   ]
  },
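  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sketch below shows what happens when the structures do not match. In the JAX versions we are aware of, `jax.tree_util.tree_map` raises a `ValueError` describing the mismatch. The exact message wording may differ between versions. The helper `try_tree_map` and the trees it is called with are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "def try_tree_map(a, b):\n",
    "  # Hypothetical helper: returns the mapped tree, or the error message if the structures differ.\n",
    "  try:\n",
    "    return jax.tree_util.tree_map(lambda x, y: x + y, a, b)\n",
    "  except ValueError as e:\n",
    "    return f'ValueError: {e}'\n",
    "\n",
    "print(try_tree_map([1, 2], [10, 20]))      # same structure: works\n",
    "print(try_tree_map([1, 2], [10, 20, 30]))  # different list lengths\n",
    "print(try_tree_map({'a': 1}, {'b': 1}))    # different dict keys"
   ]
  },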
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Lla4hDW6sgMZ"
   },
   "source": [
    "## Example: ML model parameters\n",
    "\n",
    "A simple example of training an MLP shows some ways in which pytree operations come in handy:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "j2ZUzWx8tKB2"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def init_mlp_params(layer_widths):\n",
    "  params = []\n",
    "  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):\n",
    "    params.append(\n",
    "        dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),\n",
    "             biases=np.ones(shape=(n_out,))\n",
    "            )\n",
    "    )\n",
    "  return params\n",
    "\n",
    "params = init_mlp_params([1, 128, 128, 1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kUFwJOspuGvU"
   },
   "source": [
    "We can use `jax.tree_map` to check that the shapes of our parameters are what we expect:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ErWsXuxXse-z",
    "outputId": "d3e549ab-40ef-470e-e460-1b5939d9696f"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'biases': (128,), 'weights': (1, 128)},\n",
       " {'biases': (128,), 'weights': (128, 128)},\n",
       " {'biases': (1,), 'weights': (128, 1)}]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.tree_map(lambda x: x.shape, params)"
   ]
  },
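  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pytree utilities also make quick bookkeeping easy. For example, counting parameters only needs the leaves. The sketch below uses a hypothetical `toy_params` tree with the same shapes as the MLP above. It uses zeros instead of random values and a different name so that `params` is not overwritten."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical stand-in with the same shapes as init_mlp_params([1, 128, 128, 1]).\n",
    "toy_params = [\n",
    "    dict(weights=np.zeros((1, 128)), biases=np.zeros(128)),\n",
    "    dict(weights=np.zeros((128, 128)), biases=np.zeros(128)),\n",
    "    dict(weights=np.zeros((128, 1)), biases=np.zeros(1)),\n",
    "]\n",
    "\n",
    "# tree_leaves returns every array in the tree; summing their sizes counts the parameters.\n",
    "n_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(toy_params))\n",
    "print(n_params)"
   ]
  },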
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zQtRKaj4ua6-"
   },
   "source": [
    "Now, let's train our MLP:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "iL4GvW9OuZ-X"
   },
   "outputs": [],
   "source": [
    "def forward(params, x):\n",
    "  *hidden, last = params\n",
    "  for layer in hidden:\n",
    "    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])\n",
    "  return x @ last['weights'] + last['biases']\n",
    "\n",
    "def loss_fn(params, x, y):\n",
    "  return jnp.mean((forward(params, x) - y) ** 2)\n",
    "\n",
    "LEARNING_RATE = 0.0001\n",
    "\n",
    "@jax.jit\n",
    "def update(params, x, y):\n",
    "\n",
    "  grads = jax.grad(loss_fn)(params, x, y)\n",
    "  # Note that `grads` is a pytree with the same structure as `params`.\n",
    "  # `jax.grad` is one of the many JAX functions that has\n",
    "  # built-in support for pytrees.\n",
    "\n",
    "  # This is handy, because we can apply the SGD update using tree utils:\n",
    "  return jax.tree_map(\n",
    "      lambda p, g: p - LEARNING_RATE * g, params, grads\n",
    "  )"
   ]
  },
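  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The comment in `update` says that `jax.grad` returns gradients with the same pytree structure as its input. A small sketch can check this. It uses a hypothetical two-leaf parameter dict and loss, both named `toy_*` so that the MLP's `params` and `loss_fn` are left untouched."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# Hypothetical parameters and loss, unrelated to the MLP above.\n",
    "toy_params = {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)}\n",
    "\n",
    "def toy_loss(p, x):\n",
    "  return jnp.sum(x @ p['w'] + p['b'])\n",
    "\n",
    "toy_x = jnp.ones((4, 2))\n",
    "toy_grads = jax.grad(toy_loss)(toy_params, toy_x)\n",
    "\n",
    "# The gradient pytree mirrors the parameter pytree: same keys, same shapes.\n",
    "print(jax.tree_util.tree_structure(toy_grads) == jax.tree_util.tree_structure(toy_params))\n",
    "print(jax.tree_util.tree_map(lambda g: g.shape, toy_grads))"
   ]
  },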
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "B3HniT9-xohz",
    "outputId": "d77e9811-373e-45d6-ccbe-edb6f43120d7"
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAhYAAAGdCAYAAABO2DpVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAABGgUlEQVR4nO3deVxU5f4H8M/MsAsziopAouKShpiliYFpaaaomd4Wu5ZrXksTzXu7JnZ/XfKaqbflWmpo3lLLzLqZa4aZuW9YZEm4i2aKoiKLC9uc8/vjOMgwM3DOcGbl8369eBnDOTNPZPCZ5/k+30cjiqIIIiIiIhVoXT0AIiIi8h4MFkRERKQaBgsiIiJSDYMFERERqYbBgoiIiFTDYEFERESqYbAgIiIi1TBYEBERkWp8nP2CgiDg/PnzCAkJgUajcfbLExERkR1EUURRUREiIyOh1dqel3B6sDh//jyioqKc/bJERESkgrNnz6Jp06Y2v+70YBESEgJAGpher3f2yxMREZEdCgsLERUVVfF73BanBwvT8oder2ewICIi8jA1lTGweJOIiIhUw2BBREREqmGwICIiItU4vcZCDqPRiLKyMlcPgwg6nQ4+Pj7cGk1EJJPbBYtr167hjz/+gCiKrh4KEQAgKCgIERER8PPzc/VQiIjcnlsFC6PRiD/++ANBQUFo3Lgx3yWSS4miiNLSUly6dAnZ2dlo06ZNtU1hiIjIzYJFWVkZRFFE48aNERgY6OrhECEwMBC+vr44c+YMSktLERAQ4OohERG5Nbd8+8WZCnInnKUgIpLPrWYsiIiIyD5GQUR6dh5yi4oRFhKAuOhQ6LTOf6POt2IeYtu2bdBoNMjPz5d9T4sWLTB37lyHjUmphx56CJMnT674XI3xudu/IxGRK6Rl5uCBOT9g6OJ9eGnlQQxdvA8PzPkBaZk5Th8Lg4UKRo0aBY1Gg3Hjxll8bcKECdBoNBg1apTzB+bmDhw4gOeff17WtUuXLkX9+vVr9RxERN4oLTMH45dnIKeg2OzxCwXFGL88w+nhgsFCJVFRUVi5ciVu3rxZ8VhxcTFWrFiBZs2auXBk6iotLVXtuRo3boygoCCXPwcRkacyCiKmr8+CtQYNpsemr8+CUXBeCwfvDBaCEcjeCRz6SvpTMDr8JTt16oSoqCh8/fXXFY99/fXXaNasGe69916za0tKSjBp0iSEhYUhICAADzzwAA4cOGB2zcaNG3HnnXciMDAQPXv2xOnTpy1ec9euXejevTsCAwMRFRWFSZMm4fr167LHPGrUKAwePBjTp09H48aNodfrMW7cOLPw8NBDDyEpKQmTJ09Go0aN0LdvXwBAZmYm+vXrh+DgYDRp0gTDhw/H5cuXK+67fv06RowYgeDgYEREROCdd96xeP2qyxj5+fl44YUX0KRJEwQEBCA2NhYbNmzAtm3bMHr0aBQUFECj0UCj0eD111+3+hy///47Bg0ahODgYOj1egwZMgQXL16s+Prrr7+Oe+65B59++ilatGgBg8GAP//5zygqKpL9fSMichfp2XkWMxWViQByCoqRnp3ntDF5X7DIWgfMjQWWPQqsGiP9OTdWetzBnnvuOSxZsqTi848//hijR4+2uO6VV17BqlWrsGzZMmRkZKB169bo27cv8vKk//Bnz57F448/joEDB+LgwYP4y1/+guTkZLPnOHnyJBITE/HEE0/g119/xRdffIFdu3YhKSlJ0Zi3bNmCw4cPY9u2bfj888/x9ddfY/r06WbXLFu2DH5+fti9ezcWLlyI/Px89OrVC/feey9+/PFHpKWl4eLFixgyZEjFPVOmTMH27duxdu1afPfdd9i2bRsyMjJsjkMQBPTr1w+7d+/G8uXLkZWVhdmzZ0On0yEhIQFz586FXq9HTk4OcnJy8Pe//93qcwwaNAh5eXnYvn07Nm/ejFOnTuHpp5+2+N6tWbMGGzZswIYNG7B9+3bMnj1b
0feNiMgd5BbZDhX2XKcG79oVkrUO+HIEUHVSqDBHenzIJ0DMYw57+WHDhmHatGk4c+YMAGD37t1YuXIltm3bVnHN9evXkZqaiqVLl6Jfv34AgMWLF2Pz5s346KOPMGXKFKSmpqJVq1YV7/Lbtm2LQ4cOYc6cORXPM2vWLDz77LMVxZBt2rTB+++/jwcffBCpqamy+y34+fnh448/RlBQENq3b49//etfmDJlCmbMmFGxzbJNmzb497//XXHPG2+8gXvvvRdvvvlmxWMff/wxoqKicOzYMURGRuKjjz7C8uXL8fDDDwOQwknTpk1tjuP7779Heno6Dh8+jDvvvBMA0LJly4qvGwwGaDQahIeH23yOLVu24NChQ8jOzkZUVBQA4JNPPkH79u1x4MABdOnSBYAUQJYuXYqQkBAAwPDhw7FlyxbMnDlT1veMiMhdhIXI+1kv9zo1eE+wEIxA2lRYhArg1mMaIC0ZaDcA0OocMoTGjRtjwIABWLp0KURRxIABA9CoUSOza06ePImysjJ069at4jFfX1/ExcXh8OHDAIDDhw+ja9euZvfFx8ebff7LL7/g119/xWeffVbxmCiKEAQB2dnZuOuuu2SNuWPHjmY1CvHx8bh27RrOnj2L5s2bAwA6d+5s8dpbt25FcHCwxfOdPHkSN2/eRGlpqdm/Q2hoKNq2bWtzHAcPHkTTpk0rQoU9Dh8+jKioqIpQAQAxMTGoX78+Dh8+XBEsWrRoUREqACAiIgK5ubl2vy4RkavERYciwhCACwXFVn/7aQCEG6Stp87iPcHizB6g8Hw1F4hA4TnpuujuDhvGc889V7EcsWDBAoe9zrVr1/DCCy9g0qRJFl9Tu1i0Xr16Fq89cOBAsxkUk4iICJw4cULxaziz06qvr6/Z5xqNBoIgOO31iYjUotNqkDIwBuOXZ0AD87fWpg4WKQNjnNrPwntqLK5drPkaJdfZKTExEaWlpSgrK6sodKysVatWFfUKJmVlZThw4ABiYmIAAHfddRfS09PN7tu3b5/Z5506dUJWVhZat25t8aHksKxffvnFbCfLvn37EBwcbPauv6pOnTrht99+Q4sWLSxeu169emjVqhV8fX2xf//+inuuXr2KY8eO2XzOu+++G3/88YfNa/z8/GA0Vl+Ee9ddd+Hs2bM4e/ZsxWNZWVnIz8+v+N4SEXmbxNgIpA7rhHCD+XJHuCEAqcM6ITE2wqnj8Z5gEdxE3evspNPpcPjwYWRlZUGns1xyqVevHsaPH48pU6YgLS0NWVlZGDt2LG7cuIExY8YAAMaNG4fjx49jypQpOHr0KFasWIGlS5eaPc/UqVOxZ88eJCUl4eDBgzh+/DjWrl2ruHiztLQUY8aMQVZWFjZu3IiUlBQkJSVV28Z6woQJyMvLw9ChQ3HgwAGcPHkSmzZtwujRo2E0GhEcHIwxY8ZgypQp+OGHH5CZmYlRo0ZV+5wPPvggevTogSeeeAKbN29GdnY2vv32W6SlpQGQli+uXbuGLVu24PLly7hx44bFc/Tu3RsdOnTAs88+i4yMDKSnp2PEiBF48MEHcd999yn6vhAReZLE2AjsmtoLn4+9H+/9+R58PvZ+7Jray+mhAvCmYNE8AdBH4vbkT1UaQH+HdJ2D6fV66PV6m1+fPXs2nnjiCQwfPhydOnXCiRMnsGnTJjRo0ACAtJSxatUqrFmzBh07dsTChQvNCiUB6R3+9u3bcezYMXTv3h333nsv/vnPfyIyMlLRWB9++GG0adMGPXr0wNNPP43HHnusYiunLZGRkdi9ezeMRiP69OmDDh06YPLkyahfv35FeHjrrbfQvXt3DBw4EL1798YDDzxgUatR1apVq9ClSxcMHToUMTExeOWVVypmKRISEjBu3Dg8/fTTaNy4sVkxqYlGo8HatWvRoEED9OjRA71790bLli3xxRdfKPqeEBF5Ip1Wg/hWDTHonjsQ36qhS9p5A4BGFEXndc0AUFhYCIPBgIKCAotf
vsXFxcjOzkZ0dLR9p0hW7AoBrK40OXhXiKcZNWoU8vPzsWbNGlcPxa3V+u8lEZEXqO73d2XeM2MBSKFhyCeAvsrUjz6SoYKIiMgJvGdXiEnMY9KW0jN7pELN4CbS8oeDtpgSERHRbd4XLAApRDhwS6m3qFoQSkREVFvetRRCRERELsVgQURERKpxy2Dh5I0qRNXi30ciIvncKliYGkpVPrabyNVMzbiqtgInIiJLblW86ePjg6CgIFy6dAm+vr7VdmokcjRRFHHjxg3k5uaifv36VjupEhGRObcKFhqNBhEREcjOzq44epzI1erXr1/tce1ERHSbWwULQDpsqk2bNlwOIbfg6+vLmQoiIgXcLlgAgFarZetkIiIiD8QiBiIiIlINgwURERGphsGCiIiIVMNgQURERKphsCAiIiLVMFgQERGRahgsiIiISDUMFkRERKQaBgsiIiJSDYMFERERqYbBgoiIiFTDYEFERESqYbAgIiIi1TBYEBERkWoYLIiIiEg1DBZERESkGgYLIiIiUg2DBREREamGwYKIiIhUw2BBREREqmGwICIiItUwWBAREZFqGCyIiIhINYqChdFoxGuvvYbo6GgEBgaiVatWmDFjBkRRdNT4iIiIyIP4KLl4zpw5SE1NxbJly9C+fXv8+OOPGD16NAwGAyZNmuSoMRIREZGHUBQs9uzZg0GDBmHAgAEAgBYtWuDzzz9Henq6QwZHREREnkXRUkhCQgK2bNmCY8eOAQB++eUX7Nq1C/369bN5T0lJCQoLC80+iIiIyDspmrFITk5GYWEh2rVrB51OB6PRiJkzZ+LZZ5+1ec+sWbMwffr0Wg+UiIiI3J+iGYsvv/wSn332GVasWIGMjAwsW7YMb7/9NpYtW2bznmnTpqGgoKDi4+zZs7UeNBEREbknjahgS0dUVBSSk5MxYcKEisfeeOMNLF++HEeOHJH1HIWFhTAYDCgoKIBer1c+YiIiInI6ub+/Fc1Y3LhxA1qt+S06nQ6CINg3SiIiIvIqimosBg4ciJkzZ6JZs2Zo3749fv75Z7z77rt47rnnHDU+IiIi8iCKlkKKiorw2muvYfXq1cjNzUVkZCSGDh2Kf/7zn/Dz85P1HFwKISIi8jxyf38rChZqYLAgIiLyPA6psSAiIiKqDoMFERERqYbBgoiIiFTDYEFERESqYbAgIiIi1TBYEBERkWoYLIiIiEg1DBZERESkGgYLIiIiUg2DBREREamGwYKIiIhUw2BBREREqmGwICIiItUwWBAREZFqGCyIiIhINQwWREREpBoGCyIiIlINgwURERGphsGCiIiIVMNgQURERKphsCAiIiLVMFgQERGRahgsiIiISDUMFkRERKQaBgsiIiJSDYMFERERqYbBgoiIiFTDYEFERESqYbAgIiIi1TBYEBERkWoYLIiIiEg1DBZERESkGgYLIiIiUg2DBREREamGwYKIiIhUw2BBREREqmGwICIiItUwWBAREZFqGCyIiIhINQwWREREpBoGCyIiIlINgwURERGphsGCiIiIVMNgQURERKphsCAiIiLVMFgQERGRahgsiIiISDUMFkRERKQaBgsiIiJSDYMFERERqYbBgoiIiFTDYEFERESqYbAgIiIi1TBYEBERkWoYLIiIiEg1DBZERESkGgYLIiIiUg2DBREREalGcbA4d+4chg0bhoYNGyIwMBAdOnTAjz/+6IixERERkVyCEcjeCRz6SvpTMLpkGD5KLr569Sq6deuGnj174ttvv0Xjxo1x/PhxNGjQwFHjIyIioppkrQPSpgKF528/po8EEucAMY85dSiKgsWcOXMQFRWFJUuWVDwWHR2t+qCIiIhIpqx1wJcjAIjmjxfmSI8P+cSp4ULRUsi6detw33334amnnkJYWBjuvfdeLF68uNp7SkpKUFhYaPZBREREKhCM0kxF1VAB3H4sLdmpyyKKgsWpU6eQmpqKNm3aYNOmTRg/fjwmTZqEZcuW2bxn
1qxZMBgMFR9RUVG1HjQREREBOLPHfPnDgggUnpOucxKNKIrWYo5Vfn5+uO+++7Bnz+0BTpo0CQcOHMDevXut3lNSUoKSkpKKzwsLCxEVFYWCggLo9fpaDJ2IiKiOO/QVsGpMzdc98RHQ4clavVRhYSEMBkONv78VzVhEREQgJibG7LG77roLv//+u817/P39odfrzT6IiIhIBcFN1L1OBYqCRbdu3XD06FGzx44dO4bmzZurOigiIiKSoXmCtPsDGhsXaAD9HdJ1TqIoWPz1r3/Fvn378Oabb+LEiRNYsWIFPvzwQ0yYMMFR4yMiIiJbtDppSykAy3Bx6/PE2dJ1zhqSkou7dOmC1atX4/PPP0dsbCxmzJiBuXPn4tlnn3XU+IiIiKg6MY9JW0r1EeaP6yOdvtUUUFi8qQa5xR9ERESkgGCUdn9cuyjVVDRPUHWmQu7vb0UNsoiIiMhNaXVAdHdXj4KHkBEREZF6GCyIiIhINQwWREREpBoGCyIiIlINgwURERGphsGCiIiIVMNgQURERKphsCAiIiLVMFgQERGRahgsiIiISDUMFkRERKQaBgsiIiJSDYMFERERqYbBgoiIiFTDYEFERESqYbAgIiIi1TBYEBERkWoYLIiIiEg1DBZERESkGgYLIiIiUg2DBREREamGwYKIiIhU4+PqAajBKIhIz85DblExwkICEBcdCp1W4+phERER1TkeHyzSMnMwfX0WcgqKKx6LMAQgZWAMEmMjXDgyIiKiusejl0LSMnMwfnmGWagAgAsFxRi/PANpmTkuGhkREVHd5LHBwiiImL4+C6KVr5kem74+C0bB2hVERETkCB4bLNKz8yxmKioTAeQUFCM9O895gyIiIqrjPDZY5BbZDhWVXSiUdx0RERHVnscGi7CQAFnXzdjwG2stiIiInMRjg0VcdCgiDAGoaVNp3vUyFnISERE5iccGC51Wg5SBMQBQY7gAWMhJRETkDB4bLAAgMTYCqcM6oUE9v2qvYyEnERGRc3h0sACkcPHagLtkXSu34JOIiIjs4/HBAgDCDYGyrpNb8ElERET28YpgUVMhpwZSm++46FBnDouIiKjO8YpgUV0hp+nzlIExPJiMiIjIwbwiWAC3CznDDebLHeGGAKQO68QDyYiIiJzA4083rSwxNgKPxITzCHUiIiIX8apgAUjLIvGtGrp6GERERHWS1yyFEBERkesxWBAREZFqGCyIiIhINQwWREREpBoGCyIiIlINgwURERGphsGCiIiIVMNgQURERKphsCAiIiLVMFgQERGRahgsiIiISDVed1aIPYyCyIPLiIiIVFDng0VaZg6mr89CTkFxxWMRhgCkDIzhUetEREQK1emlkLTMHIxfnmEWKgAgp6AY45dnIC0zx0UjIyIi8kx1NlgYBRHT12dBtPF1EcC0rw/BKNi6goiIlDIKIvaevIK1B89h78kr/BnrhersUkh6dp7FTEVVV2+UYf4Px/FS7zudNCoiIu+VlpmD19dl4ULh7Z+94foAvP4Yl569SZ2dscgtqj5UmCzZfZqJmoioltIyczBueYZZqACAC4XFGMelZ69SZ4NFWEiArOvyb5YhPTvPwaMhIvJeRkFE8teHqr2GS8/eo84Gi7joUNQP9JV1rdzZDSIisrTv1BXk3yir9pqrN8qw79QVJ42IHKlWwWL27NnQaDSYPHmySsNxHp1Wg9HdomVdK3d2g4iILO09KS8wyL2O3JvdweLAgQNYtGgR7r77bjXH41RJvVqjfpDtWQsNpJ4WcdGhzhsUEZHXkbvEwaUQb2BXsLh27RqeffZZLF68GA0aNFB7TE6j02ow+/EOVr9m6ruZMjCGXTiJiGohvmUjVa8j92ZXsJgwYQIGDBiA3r1713htSUkJCgsLzT7cSWJsBBYO64QIg/lyR7ghAKnDOnELFBFRLd3fqmHF7LAPyvGcbiNe91mK53Qb4YNyAED9IF/c36qhK4dJKlHcx2LlypXIyMjA
gQMHZF0/a9YsTJ8+XfHAnCkxNgKPxITzvBAiIgcwzQ6f/vzvGOvzDXSa20se//D5DIvLB6DF42/zZ66XUBQszp49i5deegmbN29GQIC8gsZp06bhb3/7W8XnhYWFiIqKUjZKJ9BpNYhnWiYicojE8x9A9N1g8bhWI+IF3w3QnG8JxM5wwchIbRpRFGVXy6xZswZ/+tOfoNPpKh4zGo3QaDTQarUoKSkx+5o1hYWFMBgMKCgogF6vt3/kRETkGcpLgZlNAFGwfY1GB/zjAuDj57xxkSJyf38rmrF4+OGHceiQeZOT0aNHo127dpg6dWqNocJhBCNwZg9w7SIQ3ARongBoXTQWIiIyd2Bx9aECAESjdF38BOeMiRxGUbAICQlBbGys2WP16tVDw4YNLR53mqx1QNpUoPD87cf0kUDiHCDmMdeMiYiIbrt6Wt3ryK15dufNrHXAlyPMQwUgff7lcGDbHGk2g4iIXKdBC3WvI7emqMZCDarVWAhGYG6sZaioKiQC6Pdvzl4QEbkKayy8gtzf3547Y3FmT82hAgCKcqRZjax1jh8TERFZ8vED4pOqvyZ+AkOFl/DcYHHtorLr05K5LEJE5Cp9ZgAJkwBNlV87Gp30eB9uNfUWihtkuY3gJgouFoHCc9IsR3R3hw2JiIiq0WcG0Os1affH1dNSTUWXsZyp8DKeGyyaJ0i7PwpzIPvgGqWzHEREpC4fP24p9XKeuxSi1UlbSpVQNMtBRERESnlusACknR5DPpFmLqqlAfR3SLMcRERE5DCeHSwAKVxMzgQeetXGBbcOtUmcbbsbp2AEsncCh76S/mSRJxERkV08t8aiMq0OeGgqEHaXjS6cs233sbDWuTOoETDgHaD9YIcOm4iIyNt4boMsW5ScG2Lq3Gmr+JNboIiIiAA46BAyj6DVydtSKhilmYrqdpTseR+I7AzEDlZrdERERF7N82ss7CW3c+fGl1lzQUREJFPdDRZye1rcuCyFECKiuojF7aSQ9y2FyKWkpwUbaxFRXWStuF0fKfUQ4sGOZEPdnbFoniDt/pCDjbWIqK7JXAN8OdxyybiQBztS9epusNDqpC2lNbGzsVZpuYCPdp7CP9dm4qOdp1BaXs1xwURE7uS3NcCq0Ta+eKvgnQc7kg11dykEkPpUnJsk7f6wSlN9Yy0bZm3MwuKd2RAqbTiZufEwxnaPxrT+MXYPl4jI4bLWAf8bWcNFPNiRbKu7MxYmfWYATy4DghqaP66/Q2oXrnAdcdbGLCzaYR4qAEAQgUU7sjFrY1YtB0xE5CAV2/BlYv0ZWVG3ZyxMYgcDMQPlN9ayobRcwOKd2dVes3hnNl7u0w5+Psx0RORm5G7DN2H9GVnBYGEit7FWNT7de9pipqIqQZSuG9O9Za1ei4hIdTJnIAQRyNU0ROOoeCh7+0V1Ad82q+hM3g1VryMicioFMxAppcORfqbAgYMhT8VgoaLmoUGqXkdE5FTNEwB9ZHUHHaBc1OLFsknYJMQht6jYaUMjz8FgoaLh8S2g1VR/jVYjXUdE5Ha0OvzcPhmiCItlXVGUPiaWTUSacD8AICwkwAWDJHfHYKEiPx8txnaPBgD4oBzP6TbidZ+leE63ET4oBwCM7R7Nwk0icktGQcSLGU0xvmwyLiDU7Gs5aIhxZZPxrdAVABBhCEBcdKi1p6E6jsWbKpvWPwYP/T4fcTkroNPcjvz/8PkM6RHPIL7/By4cHRGRbenZecgpKEYO4rC55D7EaY8gDPnIRX2kC+0gVHovmjIwBrqapmipTmKwUNt3ryH+wmcQq/z/ptWIiL/wGfBdQ6l3BhGRm6lcMyFAi32C9YZ+Y7q1QGJshLOGRR6Gc/JqKi8F9s4HAFTN8RWf710gXUdE5Gbk1kz0jgl38EjIkzFYqOnAYkCs4UwQ0ShdR0TkZuKiQxFhCLB4Y2SiAWsrqGYMFmq6elrRdUZBxN6TV7D24DnsPXkFxpq6axEROZBOq0HKQGn5w9as
K2srqCassVBTgxayr0vLzMH09VnIKbi9phlhCEDKwBiuXRKROgSj4qMKEmMjkDqsk8XPp3D+fCKZNKIoOvVtcmFhIQwGAwoKCqDX65350o5XXgrMbFL9cohGh7Q//YzxKzItmtCY3gOkDuvE/3mJqHay1kkHilU++0MfCSTOkXW4olEQkZ6dh9yiYoSFSMsfnKmo2+T+/uZSiJp8/ID4pGovMd7/IqZ/cwIiAC0E3K/NwmPaPbhfmwUNpEAyfX0Wl0WIyH5Z64AvR1geKFaYIz2eta7Gp9BpNYhv1RCD7rkD8a0aMlSQbFwKUZtpK+ne+eYzFxodED8B6a0mI2frPvTVpiPF9xNEavIqLjkvhmJ62QhsKohDenYe4ltVOcqdiKgmFUefW3tzIgLQAGnJQLsBik9wJpKDwcIR+swAer0m7f64elqqvegyFvDxQ+7Bc+irTUeq71yL28KRh1TfuRhfNhm7T7Tm1CMRKVfj0eciUHhOuq6WJzoTWcNg4Sg+fkD8BIuHw+r5IsX3EwCwOFdEq5H686f4fooHtt6HVRl/sFiKiGpWuUjz0hF598g8Ip1IKQYLJ4vTHYGu0vJHVVoNEIkrGKVLw7KCRIxfnsFiTiKyzVqRphwKjkgnUoLFm06mu54r67p/+i7HTv9J6KtNZzEnEVlnq0izWhpAf4e09ZTIARgsnE3Bu4Rw5OED37m4u2gH0rNtz3IQUR1UbZGmLbfWXxNns3CTHIbBwtmaJ0h7yW02zb3NVIOR4vspcguvO3ZcRORZdrytfPlDHwkM+URWHwsie7HGwtm0OqlBzZcjIIWL6t9tmGouWt84BKCZM0ZIRO4uax2w7U1513afAoS1k915k6i2OGPhCjGPSe8a9PILMgOOr8dvu7+BsbzcgQMjIrdXsQQiz28B92KtMR57hRgY+SOfnIAtvV1JMAL7FwKbXpV9SwGCkRszGm2enM53HkR1UfZOYNmjNV4mAriIhkgofg/CrUDB84ioNtjS2xNodUDXcYA+EqKNmouqsc+Aa2iTNQ+ls1vKastLRF5Gbv8JEUgpHV4RKgDgQkExxi/PQFpmjoMGR8Rg4Xq3ai6kagvzcCGKgMZGjadvST5EmT3/iciLyNxZ9m75k9gkxJk9Znqfwi3s5EgMFu7gVs2FpkrNha1QYfqaKIooWTNJOlWViLyTYJSWPw59Jf0Z1bXanWUipHOHFhgH2/x6TkExt7CTw3BXiLuIeUw6FOjMHpzasQIts1fUeItWA/iXXoX4bjtoHp3LLWRE3sbW0eexTwJ75sFyZ5kUNqaXjTBbArEmt6hY9eESAZyxcC9aHRDdHTdb11yYZebGFdlHIRORh6ju6PM984CEiZY7y/SROPbgAoslEGvCQgJUHCzRbZyxcEPtuvbFxc0N0Vi8YnFQmTWm9ywaHoVM5B3kHH2euQqY9Atwdj+Eogs4XBSEE0Ed0CgkCOH6g7hYWGL1bg2AcEMA4qJDHfqvQHUXg4Ub0vn44Hx8ChrvmVRtAWdlGh6FTOQdTNvQ5Rx9fnY/0q63xvRvSpFTUAzgEACgfpCvKX5YWSgBUgbGQCfnXQuRHRgs3NS9fUfiJ0FEy33/QANck32fUHSB61tEnkrhSaW/HD6C8TsKLWYmCm6UAQAMQb7Iv/XPgDRTwT4W5GgMFm6sc79R2BjZG4e/TMHzPhsQoqm52OpwURDaO2FsRKQyU02FgkPFFv18o7rFEgT66rBgTCdcvl6CsBBp+YMzFeRofHPr5vp3bIr2Q99AT83HuCKGWDTMMhFE4LzYECeCOjh3gERUe4pPKtWgJCgCaUUtbV5h2laq1Wow6J47EN+qIUMFOQWDhQdIjI3Ae8O64tWyMRAhhYjKTJ9PLxuOMH29Ww9W2fsuGJ06ZiKSSVZNRWVSOFgXPrHGLaUAt5WS83EpxEPc37Ih/h7SAy8WAf/0/QSRuN3c5gIa4l9lw/FrSA+p0tvW3vfEOex1QeROFNZUAAD0kfi5
/VRM2dpI1uXcVkrOxmDhIXRaDVIGxmD88mJsLrkPXbRHEIZ85KI+DgjtIECL1IEx0B1Zb32dtjBHenzIJwwXRO7AjpoK9H0Txi4v4MW3tgOofiaC20rJVbgU4kESYyOQOqwTwgxB2CfEYJ2QgH1CDMIMQUgd1gmJMWE17H0HkJbMZREiV7OjpgL6O4Cu45B+puDW1tLqieC2UnINzlh4mMTYCDwSE4707DzkFhWbV3pn75S39529Lohcx86aCiTOBrQ62TUTz3VrwW2l5BIMFh5Ip9UgvlVDyy/IPU555zvArneB0JbAI28AfoHqDpCILAlGYMfbwP4PgJv58u/TR0qh4tYSptyaiUdiwu0YJFHtMVh4E5nHKePUVunPkz8AB/4LtO0PDP3cceMiquuy1gHrJwE3ryq771ZNRfqZAuQePIewkAB0bt4AEYYAXCgoZstucksMFt6keQJuBobD/8YFq2eMiLe65lh86ehG4POhDBdEjpC1DvhyuMKbNIA+EmnBgzD9re1mNRURhgA81jECH+7IZstuckuKijdnzZqFLl26ICQkBGFhYRg8eDCOHj3qqLGRQkZoMb1sBADLXhemxlo2f9Qc3QiU3nTY2IjqJMEIrH9J4U3S/6U/t5+K8Z/9YlGoeaGgGB/uyMbzPaIRbjBfFgk3BEiF3KytIBdSNGOxfft2TJgwAV26dEF5eTleffVV9OnTB1lZWahXr56jxkgypWfnYeW1e3BVOxkpVXpdyDnIDJ8PBXq8DDRP4AmpRGo4vQu4mVfzdZXpI2HsOwsvrguGaGVLqald97pfcrB9Sk/8dOaqZSE3kQspChZpaWlmny9duhRhYWH46aef0KNHD1UHRsqZqsU3CXHYXHIf4m71unhO9w3u0WXX/ATZW6UPNtMiUkf2TmXX931T2lKanY+cgn02LzO16/7pzFXrhdxELlSrPhYFBQUAgNBQ20VCJSUlKCwsNPsgx6hcLS5AW9Hr4hexlbInMjXTylqn8giJ6hglkwe3+lQYocXuE5dk3cJ23eSO7A4WgiBg8uTJ6NatG2JjY21eN2vWLBgMhoqPqKgoe1+SahAXHYoIQ4DFz7KZ5cMgirB5gJklNtMislvlc3oCDPLvS5yNtKxcPDDnB8zfelLWLWzXTe7I7mAxYcIEZGZmYuXKldVeN23aNBQUFFR8nD171t6XpBqY2n4D5m+USuGH74ydAShpHlypmRYRyZP5NTCnBbDsUWDVGOC7/0PN0xZa4KllSBO6YPzyDFldNTWQdodwSym5I7u2myYlJWHDhg3YsWMHmjZtWu21/v7+8Pf3t2twpJyp7ff09VlmP6CS/ZKB0tnoo/tJ2fTsya1A9nYpkUR3B1o8wMJOIms+HyrtrrJQQ5x/cgmMdw3C9Dk/yAr+3FJK7k4jigomyEUREydOxOrVq7Ft2za0adNG8QsWFhbCYDCgoKAAer1e8f0kj1EQK9p+Nwr2x1+WHcDNMgF+KMU/fJbjHs1JdJRT0FlVYCgw8D0WdhJVtukfwN75NVxUpetESCTQTyqS3nvyCoYutl2sWVmEIQApA2O4pZScTu7vb0UzFhMmTMCKFSuwdu1ahISE4MKFCwAAg8GAwEC2hXYnldt+7zx2CTfLBADSskhK+XPQQsAu7SSEI89qMy2bbuZJzX6GfMpwQQQA5aXA3gUyLhSlXR/BTaSPW9u6jYKI3Scuy3qppJ6t8ddH7uRMBbk1RTUWqampKCgowEMPPYSIiIiKjy+++MJR4yMVrMr4w+IxoZpmWrKwsJNIcmAxZFcvBTcBOjwpLStqdUjLzLlVrHlC1u3dWjdiqCC3p2jGQsGqCbmRG6XlVh/fJMRhfJllMy1ZeEoq1XWCUfp/4MT38u+pdJ5PWmYOxi/PkF1XwfM/yFPwrJA6oEuLhvguK9fq16o202qt+QOTfNfIe2K5p6kSeZusdUDaVAVHn0Paeto8AYBUAzV9fRaLNckr1apBFnmGkQktqm3p
XbmZ1h7Rdk8SC3JPUyXyJoe+luqMlIQKAHh0bkVNxdLd2bK2lQI8/4M8D2cs6gA/Hy2e7x6NRTtq3gWSLrTDebEBInC1+vNF9HdUvPsiqjM2/R+wd57y+9r2hzHmT5j//TEs2X0a+TfLZN2W1LMV/vpIW85UkEdhsKgjpvWXGmct3pldbbGmVNQ5Eqm+cwGxmsPLEmff7mdhWmu+dtGs2p3Iq3z3mh2hQgvEv4i0O5KQ/MZm5N+QFyhMurVuzFBBHkdRHws1sI+Fa5WWC/h072lsO3YJO4/b3uLWV5uOWb7/RajmmvkXqvaxsLbWHFgf6Poi0OPvDBjkHUpvArMiAVGQd32rXkDr3kCXsUg7cgXjlmcoejlTseauqb0YLMhtyP39zWBRR8lpyKOFgPfvv4aGl/YD0CC4XU/ExPeHzufWRFfWOumwMlslaGymRd4gax2wdgJQouAAxZEbgOjuMAoiOiucqTDFCNZVkLtxSIMs8h6mA8suFBTbrEwXoUXSPj2AR6QHTgL1t/6A2Y93QGJMmDRTUV1du6mZ1v0vAm37c4mEPE9N4dmaSvVH+05dUbz8Ec7OmuThuCukjrJ1YFll1n6U5t8ow7jlGUjftl5+Vfy+D6RDmebG8ih28hyCsebwbE2l+qO9J68ouvW1AXdh19ReDBXk0Rgs6jDTgWXhBvOjl+Ws6G7Yc1D5Cxael2YwGC7IE5zZo3xL6RNLqiz9yQ8lEYYAjOoWzZoK8nhcCqnjEmMj8EhMeMWBZZeLSjDjm8M13nfsRj3Az84XXTsBuDMR8LH3CYicQGkDuPiJQIfHzQ4ANAT6yr6dDbDIWzBYkNmBZWsPnpN1T7rQDjcDmiCw2I7umyWFwLvtpIZBLOwkdyW3AZxGC8QnAX1mYP0v5/Hq6kMoKr7dRr/KmaaWtwNY8AwLNcl7cCmEzISFBNR8EaR+F793TbH/hW5ckYriuCxC7qp5AqCPRLWLg/56YNp5oM8MjP3kACZ+/rNZqABqXgxZ8My96H83QwV5DwYLMhMXHYpwvX+N10UYAtD6wWek49MDG9j5aiKwfhJwajtPSiXXK70JfPMy8OmfpD/LS4HEObe+WDVcaKSPQQsAv0DM/OY3bLZxHk/lOyoL1/tj4bBO6H93pEr/AkTugX0syEJaZk6NDX0WVt5jLxiBHW8D+1OBm1fte1F9pPRDnEsj5AqfDwWObrR8vG1/oONQyyZw+juk3R8xj6G0XEDb//tWVpnmawPuQqMQf4SFSCeVsqaCPAkbZFGtpGXmIPnrQxZ78BsE+WLW4x2srwcLRuD0LumHdNl1ha946wfskE8YLsi5bIUKk7b9gaeXW7StN0KL9Ow8fHHgd6w5KG/3yHt/vgeD7rlDpYETORcbZFGtmHaL7Dt5BXtPXQYgFXje37KhxbusylXwYSGxiBucCt3/Rih8RRGABkhLBtoNYCMtco7Sm9WHCkD6enkpEN294qG0zBxMX58l+4RSE7k1TESejMGCbNJpNejWphG6tWlk8xprP2DrBwbijXZzMOCP/0Bz7YKCVxSBwnPSO8NKP8SJHOa7f8i7bvP/AQPeAQBsOHgOSSsPKn4pfYAP4qJDFd9H5GlYvEl2S8vMwfjlGRbv2vJvliHpYBQ6X5+L4zETlT/xqe3Aoa+A7J0s6iTHyVwD/LRM3rV5pwAAM7/JsitUAMDMP3VgTQXVCZyxILsYBRHT12dVW7CWd1NAn4x4fN2zDe79bbb8LoY737r9zyzqJEf47jVgz/uyLzc2aImXVmRgw685dr3cIzFhGNiRuz+obuCMBdklPTtP1vqyCODFjKYwTjoEDF+rfGtqYQ77XZC6flujKFSIAB7K6GVXqNAAGNu9BRaP6KL4XiJPxWBBdsktkl+0llNQjNc3HEZp8x7AwPdR0QNAlltzImnJXBah2hGMwIkfgNXjZN8iAviuvDPO3lD2Un1imuC1AXfh6Bv98I8B7ZXdTOThGCzILkqr2z/d9zva
vfYtZp1uLW0p1SvpNFipqJPIHlnrgLdaAcv/BJTflHWLCGC7Jg4vlL+s6KUiDAFIHdYZY7q3hJ8Pf8RS3cMaC7JLXHQoIgwBirbbCSKwaEc2zuU3xXuTDkF3dq/UF+DSEWDHWzU/gdJDoYgAKVR8OVz5bb2WYtRG5Qfl8TAxqusYp8kuOq0GKQNj7Lp3w685SJizDWnXWwMdngSiH5R3Y3ATaTo7eyd3jZA8ghFYN0n5fUGNcCKks+Lb5v/5Hh4mRnUegwXZLTE2AguHdUL9IPlHQ5tcLCrB+OUZSMvMkXHYk0ZqoXz9CjA3Flj2KLBqjPTnW62AbXMYMMiSYAT2pQLF8tvMi7c+jt73OhoFByl6ubHdo/Eou2oSsaU31Z5REDH/hxNYsjsb+TfLar6hkvqBvljwbCfcX7Ibuv+NvPVo5b+St8JGwkRgzzzYPCsysIFUGMptqQRIyx9Vz/eQQRSBReWPYrbxGYTr/VFcLqDgRlm126q1GilUTOtv3wwekafgWSHkdEZBxOvrMvHpvt8V3xthCMAHnf6w7HehvwPo+yawaZq8XxJPLgNiByt+ffIiWeukLcqyjgW7rUAMQnLZWHwrdLX4msbGsz3ZqSnefLwDizSpTmCwIJcoLRfQ7rVvISj8W2VaBEl9tiMSg7PNDnvCmT3SsofcZ3pwKvDgKzxvpC4SjNJymcKZigIxCJ1LFqLcSj27v48WDYJ8caGwpOKxCEMAUgbGsJ6C6hQeQkYu4eejxdju0Vi0I1vRfbeOIMP0DUfxyNRe5lX1inaDiMD22UD6Ii6N1EVn9igKFaa3VcllY62GCgAoKRcw54m74eeju3XQHo88J6oO5+9IddP6x+CFHtFQ+nNXhNRMKz07z/wLwU2UD+LmVWmLITt21i12bEleVP6o1eWPytb8fA7xrRpi0D13IL6V5Qm/RHQbgwU5xLT+MTgyox+e7NRU8b27T1zG2oPnsPfkFRgFsdKuETuwY2fdYNqGfOmI7FsuiyEYXzYJs43P1Hjt9VL+HSKSi0sh5DB+Plq8PaQjeseEWRytXp35W09U/HPFWnbiHLsK8ngMex2gYAeIIAIFCMaLZZOwX4iBIPO9VZcWCs+4IarDOGNBDpcYG4FdU3vhszFdUT9QWc+LCwXFUr8LoYvUCjwwVPkA2LHTe2WukZa8ZIYKAEgu+wv2CrGyQ4UGwMiEaPvHSFTHMFiQU+i0GnRr0wizn+hgzxFkmL4+C8Z2A4EpJ4AHkxU8A+yr0SD399saYNVom1+uOrd1AQ0xvmwyNglxil7m+R7R3E5KpAD/byGnSoyNQOqwTgg3yD/EzKyoU6sDek4Dnloq485bHTubJ9g7XHJXWeuA/40ERMHmJabo+X75YPy59P/wQMl7ikKFBsALPdj4ikgp1liQ0yXGRuCRmHCkZ+cht6gYxy8WYf7WkzXeZ3ZUe/vBgOZTYP1LwM08K1ff+rWSONu8n4VglGouKvfJYL8LzyIYpZoKmU4ITbFPUBYOnugUiVmPd+RMBZEdGCzIJXRaDeJbNQQA7D15RVawCAsJgFEQKwJJWEg3xL18HLpd7wD7PwBu5t++WB8phYrKfSysFfkFNQT6v8tunZ5kx9uKelXkor6ip+csBVHtMFiQy5mOYL9QUGx1z4cGQLghAFevl+CBOT+Y7S6pH+iL0d0eR9LLL98+ht3aTIStNs83rgBfjQQOPw488V/OXri7rHXAtjdlXSqIUl1FutBO1vWhQb54Y3AH9L+b3TSJaoMtvcktpGXmYPzyDABWjyDD8z2i8eGObJubTYP8dHihR0sk9Wpj2bxIbptnHmTm3hS26xZEyC7WnPxwG0x82MrfHSKqIPf3NxcQyS3YKuoMNwRgwTOdsO6XnGo7WNwoNeI/3x/H3a9vwnvfH5caa5nIbfNs6tbJY9jdk4J23eWiFi+WTZIVKl7oEY3Jj9zJUEGkEi6FkNuoWtRpOpMhPTtPdnOt66VG/Of7Y1iyJxuz
H+8gHRKltI/FtjeB/YuAR9+VikTJPcj87yiKwMSyiUiroU13gyBfzBwci/5329nVlYisYrAgt1K5qNPEbDeITPk3yjBueQYWDuuERLvOGrkibWfcEQuM+R7wC1T+HKQO006eXHntut8tf7Lao8+f69YCj8SE8yAxIgdhsCC3FxYiv+dFVdPXZ+GRKQ9CF9RQKtRU6mIm8GY4cGc/4JmVdo+D7JS1DmLaVGgqLYGIIqCxkgekYs1QLDAOtvpU4TzqnMgpWGNBbs+0a8QeOQXFSD9TIG0prY1j3wKLHqrdc5AyWesgfjkCopW6iqol56aSmullIyxadSf1bI3Px96PXVN7MVQQOQGDBbk9nVaDlIExSpp4m8ktKpb6VCRMqt1Acn4GvhjBwk4HMwoi9h7PRcHqlyGKosUPKWuzFbbadUcYAvDXR+7kUedETsRgQR7BtGukfpCyQ8yASkspfWbA+ORSlPkE2z+Qw2uB2c2kcypIdWmZOXhgzg94b8kyGMpyYSsLmMLF+2XVt+tOGRjDQEHkZAwW5DESYyPw0/89gr/2boMgv5obWWkgvWONi5ZORE3LzEHnr+uh7bWFeKfsSVwV69k3kNJrUmHnd6/Zdz+ZMQoidh67hCEL92Dc8gzkFBQjDPmy7j0hSu26qy5/1A/ylQp3ufRB5HQs3iSPotNq8FLvO5HUqw3m/3ACi3acxI1Sy6UJ03tU0zvWtMwcjLvVgAvQYp7xcSwwDkaSbjX+6rPK6vR6jfa8D/gEAg9NZcdOOxgFEfN/OIEPtp1ASfntw8S0ENBIky/rOaq2667np8PzPVohqVdrzlQQuQg7b5JHk345HceS3aeRf7Os4vGISjsAjIKIbrO34EJhidXnSNYtxws+G+0LFwAQUB94bB47dspk+m+2aMcpi1DYV5uOFN9PEKmxdrDcbaZ23Q+UvAcB2lut3VtY77xKRKqQ+/ubwYK8gvnhZAFmPQr2nryCoYv3VXv/NN1neN7nG/vDBQA8tYwNtapRWi7g1a9/xfpfc8xmKEz6atOR6jsXAMxqK6puLzXtAHlF93d0e3Q0wg2B7ElB5ARyf39zKYS8grXGWiZyGmzNMj6Lg2JrvO87D74ay196svxvJHAxGXjoFS6NVDFrY1a1Z734oBxv+n4EDSx3fVT9/AIa4l9lwzH4qb+whoLIDbF4k7ye3AZb3wpd0bbkE6wr72LRJ0G2HbOB2VF1/rwRoyBi9/HLeHvTETy+YBcWVRMq+mrTsc//RTTUFNU4Y/SvsmF4yn8hBj8zjqGCyE1xKYS8Xk01FtYkavfh374fQq9R3k68gk8gMOgDoMPj9j+Hh6i8FHX68g0s2Z1tVvNii2n5w9pMhTXHHpiLVr1GcdmDyAVYY0FUifmuEHm0EPAfnwV4TLe3drUXEfcAfd4Amid4zRKJURCx58RlfJ3xB05dvoaTl67hWomyJSQflONX/zEIRJn87+/IDUB0d+UDJqJaY7AgqiItMwfJXx9C/o2a30lX1l+7F/N850FX2zfJ+kggcY5H7h4xCiL2nbqCvSev4MSlIvxwOBelRvt+dGghYIJuNV70WYdAjdz/Fhrp+zf5kNeEMyJPw2BBZIVRELHv5BUs338aO45fxvUSeXUQ/bR78IHvfADypuytEaGR+msM+cSjwoW9gcyavtp0zPL9L0I11xTeqfG47xuRt+GuECIrdFoNurVphG5tGlX0U/jP98drvO9bIQGLyk/jBZ8Ndr+2BiJEaIC0ZGReAU6fycaZkhAU39EV8a2b4P6Wrj3PwlQncf7qDRz8Ix+ABjdLjfgq4w9Vnr+/di8W+M5TfqN/iFSrwlBB5BE4Y0F1XlpmDl7+8hdct9LBs6p+2v34t++HCNHcVO31C8UA/Le8P5b7DcHIbq3RolEQwkICcE9UfXy69zS+++0CCkvKcVd4CJ7sHIWE1o1kB5DKSxiAiPiWjdCpeQN8uvc0Nv2Wg6KScrQLN6BFaCC+/OkPRQWu
StSqT8iw1UDrXqqPiYiU4VIIkQJGQcS8Lccxb+txGGuoQTTVCLzgswHBGvV+EV8TA7Co/FGcEcORi/pIF9pZnIEBSG2r3xnSEY/EhGPPicv4309nceRCEUL8dejbPgLPdG2OLw78jh3HLyM9Ow83y1y77TVZtwIv+GywL1T4hQDJZ1hXQeQGHBosFixYgLfeegsXLlxAx44dMW/ePMTFWZ4sWJuBEbmCURCx69glvPltFk7kXkd19YlSwFiDF3zWqxIwqnaYzBEbYIexA24iAL+LYfjE2AfllVYv/X20VjtYuhMflOOo/0hoIdoXLJ5cJh15T0Qu57Bg8cUXX2DEiBFYuHAhunbtirlz5+J///sfjh49irCwMNUGRuRqppqDTb/lYFXGORQVl1u9TgsBK33/hS7aY7XbllqFtVbW3xjj8FL5JKszGe5GCwEpPksx0ud7+54gYRLQZ4a6gyIiuzksWHTt2hVdunTB/PlShbwgCIiKisLEiRORnJys2sCI3IkpZHy48yS2HblktYvkNN1n+IvPRug0jl1dLBV1mFg2EZsEebOErjC03k94TfMRgsrzld/sr5cOdeO5K0RuxSHBorS0FEFBQfjqq68wePDgisdHjhyJ/Px8rF271uKekpISlJTcniYuLCxEVFQUgwV5rNJyAZ/uPY3TV27gzJXrOHA6DzfLpCUJH5RjhO47PKA9hM7a4zBobqj++qb/Y8eVTXa7cNGrXWPMCPoSkVmLoXjyxq8ekPAS0OPvrKkgckMO2W56+fJlGI1GNGnSxOzxJk2a4MiRI1bvmTVrFqZPn67kZYjcmp+PFmO6t6z43LRtdcHWEyg1+uBjY398bOwPLQTEaY8gDPm4BD3m+b6PRor7N1jSaKRwkeK7DIVlQWiMwmqLPR3N30eLRztEYNYTd8PvyDrgq8XKn6T7K0DPZAYKIi+gaMbi/PnzuOOOO7Bnzx7Ex8dXPP7KK69g+/bt2L9/v8U9nLGgusIoiNhz/DJW/fwHsi9fx6nL183qMp4M+BFvie8CsL/JVnVKRB+sM8ZjWvlYsyJPtTzaIRwtGgVDhIj6gb5oFOx/+8hyCED2TuCLZ4FSheEpPgnoO1P18RKRuhwyY9GoUSPodDpcvHjR7PGLFy8iPDzc6j3+/v7w9/dX8jJEHkmn1aB728bo3rYxAPODucJCAhAX3R/ZKwsQfewjh7y+v6YcT/nsxBO6ndgjxOCE2NTqbhKltBpgbPdoTOsfY/2CrHVA2lSg8LyyJ9bogPgJLNAk8jJ2FW/GxcVh3jypg54gCGjWrBmSkpJYvEkkgzFzNYQ1SfAtN39nLwLK6xKsqLqbxChqsLh8AGYbn6n2vnr+OjzfvRWa1g+o6LzZomEQhse3gJ9PpSUWwSjNTmRvB459B+RmKh9kl78AfWcBPn7K7yUil3DodtORI0di0aJFiIuLw9y5c/Hll1/iyJEjFrUXtRkYkVcz/XI+s0tKFFodxIxl0BTlVFyiZtCABjhfLxZ5909F6y6J+HT/WbPOm091biqvo2fWOmD9JODmVfsHFNQI+Psx1lMQeRiHNsiaP39+RYOse+65B++//z66du2q6sCI6hzBCJzZA1y7CGybDVyp+QwTu/gbgLb9gLxsoKQQaNIeuGcY0LKH7V/2ghHY9m9gx+zav/5Ty7iVlMgDsaU3kadLmwbs+8B5r+cXLDWlatAc+OOA9FhoSyAkAvjmb7WbpTBh0ysij8VgQeQNykuB9EXA7/ukPg+xQ4BVo6WZBk8S1Ajo/w7bcxN5MAYLIm+V+TXw1WhXj0KewAbS0keLB1hTQeTh5P7+dv8DB4jIXOzjQNv+rh6FDBpg4PtAywcZKojqEAYLIk809HOpsZS7Co4AhnwCxDzm6pEQkZNxKYTIk5WXAvtTgSMbpc+1PsCZ3YDVY9KcpP3jwBP/5SwFkZdxSOdNInIzPn5At5ekD5PyUmDDS8DBFU4ejBZISOKuD6I6jsGCyNv4+AGD
U4E7+wHrXwJu5jn29SI7SXUfcS+wkyYRMVgQea2Yx4B2A4DTu6QunxoAl08AWWugylJJYCgw8D3WURCRGQYLIm+m1Um7Mlo+ePux8lJg30LgyIbbnTdDWwE/fwJUaileQaMD7n8RaPPI7YDS/AEgujvrKIjIAos3iUhiaileeM6882aXsVziICIWbxKRQlqdNAsBAB3/7NqxEJHHYh8LIiIiUg2DBREREamGwYKIiIhUw2BBREREqmGwICIiItUwWBAREZFqGCyIiIhINQwWREREpBoGCyIiIlKN0ztvmjqIFxYWOvuliYiIyE6m39s1nQTi9GBRVFQEAIiKinL2SxMREVEtFRUVwWAw2Py60w8hEwQB58+fR0hICDQajTNfWpHCwkJERUXh7NmzPCxNRfy+qo/fU/Xxe+oY/L6qz5nfU1EUUVRUhMjISGi1tispnD5jodVq0bRpU2e/rN30ej3/B3AAfl/Vx++p+vg9dQx+X9XnrO9pdTMVJizeJCIiItUwWBAREZFqGCxs8Pf3R0pKCvz9/V09FK/C76v6+D1VH7+njsHvq/rc8Xvq9OJNIiIi8l6csSAiIiLVMFgQERGRahgsiIiISDUMFkRERKQaBgsZTp8+jTFjxiA6OhqBgYFo1aoVUlJSUFpa6uqhebSZM2ciISEBQUFBqF+/vquH47EWLFiAFi1aICAgAF27dkV6erqrh+SxduzYgYEDByIyMhIajQZr1qxx9ZA83qxZs9ClSxeEhIQgLCwMgwcPxtGjR109LI+XmpqKu+++u6IxVnx8PL799ltXDwsAg4UsR44cgSAIWLRoEX777Tf85z//wcKFC/Hqq6+6emgerbS0FE899RTGjx/v6qF4rC+++AJ/+9vfkJKSgoyMDHTs2BF9+/ZFbm6uq4fmka5fv46OHTtiwYIFrh6K19i+fTsmTJiAffv2YfPmzSgrK0OfPn1w/fp1Vw/NozVt2hSzZ8/GTz/9hB9//BG9evXCoEGD8Ntvv7l6aNxuaq+33noLqampOHXqlKuH4vGWLl2KyZMnIz8/39VD8Thdu3ZFly5dMH/+fADSWTxRUVGYOHEikpOTXTw6z6bRaLB69WoMHjzY1UPxKpcuXUJYWBi2b9+OHj16uHo4XiU0NBRvvfUWxowZ49JxcMbCTgUFBQgNDXX1MKgOKy0txU8//YTevXtXPKbVatG7d2/s3bvXhSMjsq2goAAA+PNTRUajEStXrsT169cRHx/v6uE4/xAyb3DixAnMmzcPb7/9tquHQnXY5cuXYTQa0aRJE7PHmzRpgiNHjrhoVES2CYKAyZMno1u3boiNjXX1cDzeoUOHEB8fj+LiYgQHB2P16tWIiYlx9bDq9oxFcnIyNBpNtR9Vf0CfO3cOiYmJeOqppzB27FgXjdx92fM9JaK6YcKECcjMzMTKlStdPRSv0LZtWxw8eBD79+/H+PHjMXLkSGRlZbl6WHV7xuLll1/GqFGjqr2mZcuWFf98/vx59OzZEwkJCfjwww8dPDrPpPR7SvZr1KgRdDodLl68aPb4xYsXER4e7qJREVmXlJSEDRs2YMeOHWjatKmrh+MV/Pz80Lp1awBA586dceDAAbz33ntYtGiRS8dVp4NF48aN0bhxY1nXnjt3Dj179kTnzp2xZMkSaLV1erLHJiXfU6odPz8/dO7cGVu2bKkoMBQEAVu2bEFSUpJrB0d0iyiKmDhxIlavXo1t27YhOjra1UPyWoIgoKSkxNXDqNvBQq5z587hoYceQvPmzfH222/j0qVLFV/jO0P7/f7778jLy8Pvv/8Oo9GIgwcPAgBat26N4OBg1w7OQ/ztb3/DyJEjcd999yEuLg5z587F9evXMXr0aFcPzSNdu3YNJ06cqPg8OzsbBw8eRGhoKJo1a+bCkXmuCRMmYMWKFVi7di1CQkJw4cIFAIDBYEBgYKCLR+e5pk2bhn79+qFZs2YoKirCihUrsG3bNmzatMnVQwNEqtGSJUtE
AFY/yH4jR460+j3dunWrq4fmUebNmyc2a9ZM9PPzE+Pi4sR9+/a5ekgea+vWrVb/To4cOdLVQ/NYtn52LlmyxNVD82jPPfec2Lx5c9HPz09s3Lix+PDDD4vfffedq4cliqIoso8FERERqYaFAkRERKQaBgsiIiJSDYMFERERqYbBgoiIiFTDYEFERESqYbAgIiIi1TBYEBERkWoYLIiIiEg1DBZERESkGgYLIiIiUg2DBREREamGwYKIiIhU8//maQ/JCima+QAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "xs = np.random.normal(size=(128, 1))\n",
    "ys = xs ** 2\n",
    "\n",
    "for _ in range(1000):\n",
    "  params = update(params, xs, ys)\n",
    "\n",
    "plt.scatter(xs, ys)\n",
    "plt.scatter(xs, forward(params, xs), label='Model prediction')\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lNAvmpzdoE9l"
   },
   "source": [
    "## Key paths\n",
    "\n",
    "In a pytree, each leaf has a _key path_. A key path for a leaf is a sequence of _keys_ (returned as a `tuple`), where the length of the sequence equals the depth of the leaf in the pytree. Each _key_ is a [hashable object](https://docs.python.org/3/glossary.html#term-hashable) that represents an index into the corresponding pytree node type. The type of the key depends on the pytree node type; for example, the type of keys for `dict`s is different from the type of keys for `tuple`s.\n",
    "\n",
    "For built-in pytree node types, the keys of any single node instance are unique: no two children of the same node share a key. Consequently, in a pytree built only from nodes with this property, every leaf has a unique key path.\n",
    "\n",
    "The APIs for working with key paths are:\n",
    "\n",
    "* [`jax.tree_util.tree_flatten_with_path`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_flatten_with_path.html): Works like `jax.tree_util.tree_flatten`, but returns `(key_path, leaf)` pairs instead of bare leaves.\n",
    "\n",
    "* [`jax.tree_util.tree_map_with_path`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_map_with_path.html): Works like `jax.tree_util.tree_map`, but the mapped function also receives each leaf's key path as its first argument.\n",
    "\n",
    "* [`jax.tree_util.keystr`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.keystr.html): Given a key path, returns a reader-friendly string expression such as `[1]['k1']`.\n",
    "\n",
    "One use case is printing debugging information about individual leaf values:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "G6E2YzhvoE9l",
    "outputId": "5aec83c8-e15e-48eb-b2c3-6fa0164344b5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Value of tree[0]: 1\n",
      "Value of tree[1]['k1']: 2\n",
      "Value of tree[1]['k2'][0]: 3\n",
      "Value of tree[1]['k2'][1]: 4\n",
      "Value of tree[2].name: foo\n"
     ]
    }
   ],
   "source": [
    "import collections\n",
    "ATuple = collections.namedtuple(\"ATuple\", ['name'])\n",
    "\n",
    "tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')]\n",
    "flattened, _ = jax.tree_util.tree_flatten_with_path(tree)\n",
    "for key_path, value in flattened:\n",
    "    print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}')"
   ]
  },
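  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`jax.tree_util.tree_map_with_path` is useful when a transformation should depend on *where* a leaf sits in the tree. As a minimal sketch (the tree and the rule of doubling only the leaves stored under the dict key `'k2'` are made up for illustration), the mapped function below receives each leaf's key path as its first argument and checks it for a `jax.tree_util.DictKey` whose `key` is `'k2'`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "example_tree = [1, {'k1': 2, 'k2': (3, 4)}]\n",
    "\n",
    "def double_under_k2(key_path, leaf):\n",
    "  # key_path is a tuple of keys, e.g. (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=0)).\n",
    "  under_k2 = any(\n",
    "      isinstance(k, jax.tree_util.DictKey) and k.key == 'k2' for k in key_path)\n",
    "  return leaf * 2 if under_k2 else leaf\n",
    "\n",
    "# Expected result: [1, {'k1': 2, 'k2': (6, 8)}]\n",
    "jax.tree_util.tree_map_with_path(double_under_k2, example_tree)"
   ]
  },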
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zrKqmANgoE9l"
   },
   "source": [
    "To express key paths, JAX provides a few default key types for the built-in pytree node types, namely:\n",
    "\n",
    "*  `SequenceKey(idx: int)`: for lists and tuples.\n",
    "*  `DictKey(key: Hashable)`: for dictionaries.\n",
    "*  `GetAttrKey(name: str)`: for `namedtuple`s and, preferably, for custom pytree nodes (more on this in the next section).\n",
    "\n",
    "You are free to define your own key types for your own custom nodes. They will work with `jax.tree_util.keystr` as long as their `__str__()` method is overridden to return a reader-friendly expression."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ohDq0kGuoE9l",
    "outputId": "9b8ff3ec-3461-482e-ff27-30dc2a7e68c9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Key path of tree[0]: (SequenceKey(idx=0),)\n",
      "Key path of tree[1]['k1']: (SequenceKey(idx=1), DictKey(key='k1'))\n",
      "Key path of tree[1]['k2'][0]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=0))\n",
      "Key path of tree[1]['k2'][1]: (SequenceKey(idx=1), DictKey(key='k2'), SequenceKey(idx=1))\n",
      "Key path of tree[2].name: (SequenceKey(idx=2), GetAttrKey(name='name'))\n"
     ]
    }
   ],
   "source": [
    "for key_path, _ in flattened:\n",
    "    print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}')"
   ]
  },
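  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of a custom key type: any hashable class whose `__str__` returns the desired text should work. Below, the container `Pair` and the key type `SlotKey` are hypothetical names made up for illustration. The container is registered with `jax.tree_util.register_pytree_with_keys`, which the next section covers in more detail, so that its children print as `<first>` and `<second>`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dataclasses\n",
    "import jax\n",
    "\n",
    "@dataclasses.dataclass(frozen=True)  # frozen=True makes instances hashable\n",
    "class SlotKey:\n",
    "  \"\"\"Hypothetical key type naming one slot of a Pair.\"\"\"\n",
    "  slot: str\n",
    "\n",
    "  def __str__(self):\n",
    "    # keystr builds its string from the str() of each key in the path.\n",
    "    return f'<{self.slot}>'\n",
    "\n",
    "\n",
    "class Pair:\n",
    "  \"\"\"Hypothetical two-element container.\"\"\"\n",
    "\n",
    "  def __init__(self, first, second):\n",
    "    self.first = first\n",
    "    self.second = second\n",
    "\n",
    "\n",
    "def flatten_with_keys_Pair(pair):\n",
    "  # Each child is paired with the key that should describe it in a key path.\n",
    "  children = [(SlotKey('first'), pair.first), (SlotKey('second'), pair.second)]\n",
    "  return children, None  # no auxiliary data needed\n",
    "\n",
    "\n",
    "def unflatten_Pair(aux_data, children):\n",
    "  return Pair(*children)\n",
    "\n",
    "\n",
    "jax.tree_util.register_pytree_with_keys(Pair, flatten_with_keys_Pair, unflatten_Pair)\n",
    "\n",
    "pair_flat, _ = jax.tree_util.tree_flatten_with_path([Pair(1, 2)])\n",
    "for key_path, value in pair_flat:\n",
    "  print(f'tree{jax.tree_util.keystr(key_path)}: {value}')"
   ]
  },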
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sBxOB21YNEDA"
   },
   "source": [
    "## Custom pytree nodes\n",
    "\n",
    "So far, we've only been considering pytrees of lists, tuples, and dicts; everything else is considered a leaf. Therefore, if you define your own container class, it will be considered a leaf, even if it has trees inside it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "CK8LN2PRFnQf"
   },
   "outputs": [],
   "source": [
    "class MyContainer:\n",
    "  \"\"\"A named container.\"\"\"\n",
    "\n",
    "  def __init__(self, name: str, a: int, b: int, c: int):\n",
    "    self.name = name\n",
    "    self.a = a\n",
    "    self.b = b\n",
    "    self.c = c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "OPGe2R7ZOXCT",
    "outputId": "40db1f41-9df8-4dea-972a-6a7bc44a49c6"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<__main__.MyContainer at 0x121ae9ac0>, <__main__.MyContainer at 0x1233f9910>]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.tree_util.tree_leaves([\n",
    "    MyContainer('Alice', 1, 2, 3),\n",
    "    MyContainer('Bob', 4, 5, 6)\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vk4vucGXPADj"
   },
   "source": [
    "Accordingly, if we try to use a tree map expecting our leaves to be the elements inside the container, we will get an error:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vIr9_JOIOku7",
    "outputId": "dadc9c15-4a10-4fac-e70d-f23e7085cf74"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TypeError: unsupported operand type(s) for +: 'MyContainer' and 'int'\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    jax.tree_util.tree_map(lambda x: x + 1, [\n",
    "        MyContainer('Alice', 1, 2, 3),\n",
    "        MyContainer('Bob', 4, 5, 6)\n",
    "    ])\n",
    "except TypeError as e:\n",
    "    print(f'TypeError: {e}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nAZ4FR2lPN51",
    "tags": [
     "raises-exception"
    ]
   },
   "source": [
    "To solve this, we need to register our container with JAX by telling it how to flatten and unflatten it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2RR5cDFvoE9m",
    "outputId": "94745373-abe4-4bca-967c-4133e8027c30"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1, 2, 3, 4, 5, 6]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from typing import Iterable\n",
    "\n",
    "def flatten_MyContainer(container) -> tuple[Iterable[int], str]:\n",
    "  \"\"\"Returns an iterable over container contents, and aux data.\"\"\"\n",
    "  flat_contents = [container.a, container.b, container.c]\n",
    "\n",
    "  # we don't want the name to appear as a child, so it is auxiliary data.\n",
    "  # auxiliary data is usually a description of the structure of a node,\n",
    "  # e.g., the keys of a dict -- anything that isn't a node's children.\n",
    "  aux_data = container.name\n",
    "  return flat_contents, aux_data\n",
    "\n",
    "def unflatten_MyContainer(\n",
    "    aux_data: str, flat_contents: Iterable[int]) -> MyContainer:\n",
    "  \"\"\"Converts aux data and the flat contents into a MyContainer.\"\"\"\n",
    "  return MyContainer(aux_data, *flat_contents)\n",
    "\n",
    "jax.tree_util.register_pytree_node(\n",
    "    MyContainer, flatten_MyContainer, unflatten_MyContainer)\n",
    "\n",
    "jax.tree_util.tree_leaves([\n",
    "    MyContainer('Alice', 1, 2, 3),\n",
    "    MyContainer('Bob', 4, 5, 6)\n",
    "])"
   ]
  },
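  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `MyContainer` registered, the `tree_map` call that failed earlier now works: each container is flattened into its children `a`, `b` and `c`, the function is applied to them, and `unflatten_MyContainer` rebuilds new `MyContainer` objects, carrying the name over as auxiliary data. A quick check, assuming the registration cell above has run:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "containers = jax.tree_util.tree_map(lambda x: x + 1, [\n",
    "    MyContainer('Alice', 1, 2, 3),\n",
    "    MyContainer('Bob', 4, 5, 6)\n",
    "])\n",
    "\n",
    "# The name travels through aux_data; every child was incremented by one.\n",
    "for container in containers:\n",
    "  print(container.name, container.a, container.b, container.c)\n",
    "# Alice 2 3 4\n",
    "# Bob 5 6 7"
   ]
  },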
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JXaEe76ZoE9m"
   },
   "source": [
    "Alternatively, using the key path API mentioned above, you can register this container with its keys in mind by defining what the key should be for each flattened-out value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "D_juQx-2OybX",
    "outputId": "ee2cf4ad-ec21-4636-c9c5-2c64b81429bb"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1, 2, 3, 4, 5, 6]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class MyKeyPathContainer(MyContainer):\n",
    "  pass\n",
    "\n",
    "def flatten_with_keys_MyKeyPathContainer(container) -> tuple[Iterable[tuple[jax.tree_util.GetAttrKey, int]], str]:\n",
    "  \"\"\"Returns an iterable over (key, child) pairs, and aux data.\"\"\"\n",
    "\n",
    "  # GetAttrKey is a common way to express an attribute key. Users are free\n",
    "  # to pick any other expression that fits their use cases the best.\n",
    "  flat_contents = [(jax.tree_util.GetAttrKey('a'), container.a),\n",
    "                   (jax.tree_util.GetAttrKey('b'), container.b),\n",
    "                   (jax.tree_util.GetAttrKey('c'), container.c)]\n",
    "\n",
    "  # we don't want the name to appear as a child, so it is auxiliary data.\n",
    "  # auxiliary data is usually a description of the structure of a node,\n",
    "  # e.g., the keys of a dict -- anything that isn't a node's children.\n",
    "  aux_data = container.name\n",
    "  return flat_contents, aux_data\n",
    "\n",
    "def unflatten_MyKeyPathContainer(\n",
    "    aux_data: str, flat_contents: Iterable[int]) -> MyKeyPathContainer:\n",
    "  \"\"\"Converts aux data and the flat contents into a MyKeyPathContainer.\"\"\"\n",
    "  return MyKeyPathContainer(aux_data, *flat_contents)\n",
    "\n",
    "jax.tree_util.register_pytree_with_keys(\n",
    "    MyKeyPathContainer, flatten_with_keys_MyKeyPathContainer, unflatten_MyKeyPathContainer)\n",
    "\n",
    "jax.tree_util.tree_leaves([\n",
    "    MyKeyPathContainer('Alice', 1, 2, 3),\n",
    "    MyKeyPathContainer('Bob', 4, 5, 6)\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HPX23W4zoE9m"
   },
   "source": [
    "`register_pytree_with_keys` is an extended version of `register_pytree_node`, and containers registered either way can use all the `tree_util` utilities without error.\n",
    "\n",
    "When the `*_with_path` APIs are applied to a container registered with `register_pytree_node`, the returned keys fall back to \"flat index\" keys (`jax.tree_util.FlattenedIndexKey`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "E1BwD2aZoE9m",
    "outputId": "4fe12b06-aef4-426a-a732-891affa63842"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MyContainer container[<flat index 0>]: 1\n",
      "MyContainer container[<flat index 1>]: 2\n",
      "MyContainer container[<flat index 2>]: 3\n",
      "MyKeyPathContainer container.a: 1\n",
      "MyKeyPathContainer container.b: 2\n",
      "MyKeyPathContainer container.c: 3\n"
     ]
    }
   ],
   "source": [
    "flattened, _ = jax.tree_util.tree_flatten_with_path(MyContainer('Alice', 1, 2, 3))\n",
    "for key_path, value in flattened:\n",
    "    print(f'MyContainer container{jax.tree_util.keystr(key_path)}: {value}')\n",
    "\n",
    "flattened, _ = jax.tree_util.tree_flatten_with_path(MyKeyPathContainer('Alice', 1, 2, 3))\n",
    "for key_path, value in flattened:\n",
    "    print(f'MyKeyPathContainer container{jax.tree_util.keystr(key_path)}: {value}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JgnAp7fFShEB"
   },
   "source": [
    "Modern Python comes equipped with helpful tools to make defining containers easier. Some of these will work with JAX out-of-the-box, but others require more care. For instance, a `NamedTuple` subclass doesn't need to be registered to be considered a pytree node type:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8DNoLABtO0fr",
    "outputId": "9a448508-43eb-4450-bfaf-eeeb59a9e349"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Alice', 1, 2, 3, 'Bob', 4, 5, 6]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from typing import NamedTuple, Any\n",
    "\n",
    "class MyOtherContainer(NamedTuple):\n",
    "  name: str\n",
    "  a: Any\n",
    "  b: Any\n",
    "  c: Any\n",
    "\n",
    "# NamedTuple subclasses are handled as pytree nodes, so\n",
    "# this will work out-of-the-box:\n",
    "jax.tree_util.tree_leaves([\n",
    "    MyOtherContainer('Alice', 1, 2, 3),\n",
    "    MyOtherContainer('Bob', 4, 5, 6)\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TVdtzJDVTZb6"
   },
   "source": [
    "Notice that the `name` field now appears as a leaf, as all tuple elements are children. That's the price we pay for not having to register the class the hard way."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wDbVszv-oE9n"
   },
   "source": [
    "One way around this is to use `jax.tree_util.register_static` to register a type as a node without children. A field of that type is then stored as part of the tree structure rather than as a leaf:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "executionInfo": {
     "elapsed": 59,
     "status": "ok",
     "timestamp": 1692698060536,
     "user": {
      "displayName": "",
      "userId": ""
     },
     "user_tz": -60
    },
    "id": "Rclc079ioE9n",
    "outputId": "6b6a4402-8fc1-409c-b6da-88568a612e1b"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1, 2, 3, 4, 5, 6]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from typing import NamedTuple, Any\n",
    "\n",
    "@jax.tree_util.register_static\n",
    "class StaticStr(str):\n",
    "  pass\n",
    "\n",
    "\n",
    "class YetAnotherContainer(NamedTuple):\n",
    "  name: StaticStr\n",
    "  a: Any\n",
    "  b: Any\n",
    "  c: Any\n",
    "\n",
    "\n",
    "# StaticStr is registered as a childless node, so the names become part of\n",
    "# the tree structure and no longer show up as leaves:\n",
    "jax.tree_util.tree_leaves([\n",
    "    YetAnotherContainer(StaticStr('Alice'), 1, 2, 3),\n",
    "    YetAnotherContainer(StaticStr('Bob'), 4, 5, 6)\n",
    "])"
   ]
  },
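  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the `StaticStr` name is stored as structure, it is carried through transformations untouched while only `a`, `b` and `c` are mapped. A quick check, assuming the cell above has run:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "result = jax.tree_util.tree_map(\n",
    "    lambda x: x * 10, YetAnotherContainer(StaticStr('Alice'), 1, 2, 3))\n",
    "\n",
    "# The name is passed through as static structure; only a, b, c were multiplied.\n",
    "print(result.name, result.a, result.b, result.c)  # Alice 10 20 30"
   ]
  },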
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kNsTszcEEHD0"
   },
   "source": [
    "## Common pytree gotchas and patterns"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0ki-JDENzyL7"
   },
   "source": [
    "### Gotchas\n",
    "#### Mistaking nodes for leaves\n",
    "A common problem to look out for is accidentally introducing tree nodes instead of leaves:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "N-th4jOAGJlM",
    "outputId": "23eed14d-d383-4d88-d6f9-02bac06020df"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(Array([1., 1.], dtype=float32), Array([1., 1., 1.], dtype=float32)),\n",
       " (Array([1., 1., 1.], dtype=float32), Array([1., 1., 1., 1.], dtype=float32))]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]\n",
    "\n",
    "# Try to make another tree with ones instead of zeros\n",
    "shapes = jax.tree_util.tree_map(lambda x: x.shape, a_tree)\n",
    "jax.tree_util.tree_map(jnp.ones, shapes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "q8d4y-hfHTWh"
   },
   "source": [
    "What happened is that the `shape` of an array is a tuple, which is a pytree node, with its elements as leaves. Thus, in the map, instead of calling `jnp.ones` on e.g. `(2, 3)`, it's called on `2` and `3`.\n",
    "\n",
    "The solution will depend on the specifics, but there are two broadly applicable options:\n",
    "* rewrite the code to avoid the intermediate `tree_map`.\n",
    "* convert the tuple into an `np.array` or `jnp.array`, which makes the entire\n",
    "sequence a leaf."
   ]
  },
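  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both options can be sketched as follows, reusing the arrays from the example above. The first maps `jnp.ones_like` over the arrays directly, so no intermediate tree of shapes is built. The second stores each shape as a NumPy array, which JAX treats as a single leaf, and converts it back to a tuple inside the mapped function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "\n",
    "a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]\n",
    "\n",
    "# Option 1: skip the intermediate tree_map and work on the arrays directly.\n",
    "ones_tree = jax.tree_util.tree_map(jnp.ones_like, a_tree)\n",
    "\n",
    "# Option 2: make each shape a single leaf by storing it as an np.ndarray.\n",
    "shapes = jax.tree_util.tree_map(lambda x: np.array(x.shape), a_tree)\n",
    "ones_tree_2 = jax.tree_util.tree_map(lambda s: jnp.ones(tuple(s)), shapes)\n",
    "\n",
    "print([x.shape for x in ones_tree])    # [(2, 3), (3, 4)]\n",
    "print([x.shape for x in ones_tree_2])  # [(2, 3), (3, 4)]"
   ]
  },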
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4OKlbFlEIda-"
   },
   "source": [
    "#### Handling of None\n",
    "`jax.tree_util` treats `None` as a node without children, not as a leaf:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gIwlwo2MJcEC",
    "outputId": "1e59f323-a7b7-42be-8603-afa4693c00cc"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.tree_util.tree_leaves([None, None, None])"
   ]
  },
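  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `None` is an empty node rather than a leaf, `tree_map` keeps it in place instead of passing it to the mapped function. If you need `None` to count as a leaf, the `is_leaf` argument of functions such as `jax.tree_util.tree_leaves` and `jax.tree_util.tree_map` lets you say so. A short sketch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "\n",
    "# None is kept as (empty) structure; the function only sees the integer leaves.\n",
    "print(jax.tree_util.tree_map(lambda x: x + 1, [1, None, 2]))  # [2, None, 3]\n",
    "\n",
    "# With is_leaf, each None is reported as a leaf instead.\n",
    "print(jax.tree_util.tree_leaves([None, None, None], is_leaf=lambda x: x is None))  # [None, None, None]"
   ]
  },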
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pwNz-rp1JvW4"
   },
   "source": [
    "### Patterns\n",
    "#### Transposing trees\n",
    "\n",
    "If you would like to transpose a pytree, i.e. turn a list of trees into a tree of lists, you can do so using `jax.tree_util.tree_map`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "UExN7-G7qU-F",
    "outputId": "fd049086-ef37-44db-8e2c-9f1bd9fad950"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'obs': [3, 4], 't': [1, 2]}"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def tree_transpose(list_of_trees):\n",
    "  \"\"\"Convert a list of trees of identical structure into a single tree of lists.\"\"\"\n",
    "  return jax.tree_util.tree_map(lambda *xs: list(xs), *list_of_trees)\n",
    "\n",
    "\n",
    "# Convert a dataset from row-major to column-major:\n",
    "episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]\n",
    "tree_transpose(episode_steps)"
   ]
  },
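  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When the leaves are arrays, a common variant of the same pattern stacks them into one array per leaf instead of building Python lists. Below is a minimal sketch, assuming all trees share the same structure and matching leaf shapes; `tree_stack` is a hypothetical helper name, not a JAX function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def tree_stack(list_of_trees):\n",
    "  \"\"\"Stack matching leaves of identically structured trees along a new axis 0.\"\"\"\n",
    "  return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *list_of_trees)\n",
    "\n",
    "\n",
    "steps = [dict(t=jnp.array(1), obs=jnp.zeros(3)), dict(t=jnp.array(2), obs=jnp.ones(3))]\n",
    "stacked = tree_stack(steps)\n",
    "print(stacked['t'].shape, stacked['obs'].shape)  # (2,) (2, 3)"
   ]
  },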
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ao6R2ffm2CF4"
   },
   "source": [
    "For more complicated transposes, JAX provides `jax.tree_util.tree_transpose`, which is more verbose but allows you to specify the structure of the inner and outer pytrees for more flexibility:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bZvVwxshz1D3",
    "outputId": "a0314dc8-4267-41e6-a763-931d40433c26"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'obs': [3, 4], 't': [1, 2]}"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.tree_util.tree_transpose(\n",
    "  outer_treedef = jax.tree_util.tree_structure([0 for e in episode_steps]),\n",
    "  inner_treedef = jax.tree_util.tree_structure(episode_steps[0]),\n",
    "  pytree_to_transpose = episode_steps\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KlYA2R6N2h_8"
   },
   "source": [
    "## More Information\n",
    "\n",
    "For more information on pytrees in JAX and the operations that are available, see the [Pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) section in the JAX documentation."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "last_runtime": {
    "build_target": "//learning/deepmind/dm_python:dm_notebook3",
    "kind": "private"
   },
   "name": "jax101-pytrees",
   "provenance": []
  },
  "jupytext": {
   "formats": "ipynb,md:myst"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
